# This file includes code originally from the Segment and Track Anything repository:
# https://github.com/z-x-yang/Segment-and-Track-Anything
# Licensed under the AGPL-3.0 License. See THIRD_PARTY_LICENSES.md for details.

import numpy as np

from ...utils.image import one_hot_mask
from ..layers.basic import seq_to_2d
from .aot_engine import AOTEngine, AOTInferEngine


class DeAOTEngine(AOTEngine):
    """``AOTEngine`` variant that overrides how short-term memory is updated.

    Only ``update_short_term_memory`` is overridden here; every other method is
    inherited from ``AOTEngine``.
    """

    def __init__(
        self,
        aot_model,
        gpu_id=0,
        long_term_mem_gap=9999,
        short_term_mem_skip=1,
        layer_loss_scaling_ratio=2.0,
        max_len_long_term=9999,
    ):
        """Create the engine.

        Args:
            aot_model: Model forwarded to ``AOTEngine``. ``update_short_term_memory``
                reaches its layers via ``self.AOT.LSTT.layers`` — presumably the
                base class stores it as ``self.AOT``; confirm in ``AOTEngine``.
            gpu_id: Forwarded to ``AOTEngine``.
            long_term_mem_gap: Forwarded to ``AOTEngine``. It is compared against
                ``frame_step - last_mem_step`` to decide when to write long-term memory.
            short_term_mem_skip: Forwarded to ``AOTEngine``. It sets how many recent
                short-term entries are retained.
            layer_loss_scaling_ratio: Stored on the instance. No code in this class
                reads it.
            max_len_long_term: Forwarded to ``AOTEngine``.
        """
        # layer_loss_scaling_ratio is deliberately not forwarded; it is kept
        # only on this subclass.
        super().__init__(aot_model, gpu_id, long_term_mem_gap, short_term_mem_skip, max_len_long_term)
        self.layer_loss_scaling_ratio = layer_loss_scaling_ratio

    def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False):
        """Store the current frame's LSTT memories as short-term (and maybe long-term) memory.

        For every LSTT layer, the identity embedding is fused into that layer's
        ID key/value through the layer's ``fuse_key_value_id``. The fused tensors
        are written back into ``self.curr_lstt_output[1]`` in place. Copies
        converted with ``seq_to_2d(..., self.enc_size_2d)`` form the new
        short-term memory entry.

        Args:
            curr_mask: Mask of the current frame; only read when ``curr_id_emb``
                is None. It is one-hot encoded with ``self.max_obj_num`` when it
                has 3 dimensions or a leading dimension of 1. Otherwise it is
                passed to ``assign_identity`` unchanged, i.e. treated as already
                one-hot.
            curr_id_emb: Precomputed identity embedding. When None it is derived
                from ``curr_mask`` via ``self.assign_identity``.
            skip_long_term_update: When True, long-term memory is not written even
                if the gap has elapsed. ``self.last_mem_step`` is still advanced.
        """
        if curr_id_emb is None:
            mask_shape = curr_mask.size()
            if len(mask_shape) == 3 or mask_shape[0] == 1:
                curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num)
            else:
                curr_one_hot_mask = curr_mask
            curr_id_emb = self.assign_identity(curr_one_hot_mask)

        # Each per-layer entry unpacks as (k, v, id_k, id_v). It must be mutable,
        # because the fused ID key/value are stored back into it below.
        lstt_curr_memories = self.curr_lstt_output[1]
        lstt_curr_memories_2d = []
        for layer_idx, layer_memory in enumerate(lstt_curr_memories):
            curr_k, curr_v, curr_id_k, curr_id_v = layer_memory
            curr_id_k, curr_id_v = self.AOT.LSTT.layers[layer_idx].fuse_key_value_id(curr_id_k, curr_id_v, curr_id_emb)
            # Write back in place so update_long_term_memory (below) receives the
            # fused ID key/value rather than the pre-fusion ones.
            layer_memory[2], layer_memory[3] = curr_id_k, curr_id_v
            # The fused ID key can be None, so pass None through instead of
            # converting it.
            local_curr_id_k = seq_to_2d(curr_id_k, self.enc_size_2d) if curr_id_k is not None else None
            local_curr_id_v = seq_to_2d(curr_id_v, self.enc_size_2d)
            lstt_curr_memories_2d.append(
                [
                    seq_to_2d(curr_k, self.enc_size_2d),
                    seq_to_2d(curr_v, self.enc_size_2d),
                    local_curr_id_k,
                    local_curr_id_v,
                ]
            )

        # Keep only the newest `short_term_mem_skip` entries. The oldest retained
        # entry becomes the active short-term memory.
        # NOTE(review): with short_term_mem_skip == 0 the slice is [-0:], which
        # keeps the whole list, so it would grow without bound. Confirm 0 is
        # never configured.
        self.short_term_memories_list.append(lstt_curr_memories_2d)
        self.short_term_memories_list = self.short_term_memories_list[-self.short_term_mem_skip :]
        self.short_term_memories = self.short_term_memories_list[0]

        if self.frame_step - self.last_mem_step >= self.long_term_mem_gap:
            # skip the update of long-term memory or not
            if not skip_long_term_update:
                self.update_long_term_memory(lstt_curr_memories)
            self.last_mem_step = self.frame_step


class DeAOTInferEngine(AOTInferEngine):
    """Inference front-end that spreads objects over several ``DeAOTEngine`` instances.

    Sub-engines are created on demand in ``add_reference_frame``. The number
    created is enough so that each one handles at most ``max_aot_obj_num``
    objects.
    """

    def __init__(
        self,
        aot_model,
        gpu_id=0,
        long_term_mem_gap=9999,
        short_term_mem_skip=1,
        max_aot_obj_num=None,
        max_len_long_term=9999,
    ):
        """Pass every argument positionally, unchanged, to ``AOTInferEngine``."""
        super().__init__(
            aot_model, gpu_id, long_term_mem_gap, short_term_mem_skip, max_aot_obj_num, max_len_long_term
        )

    def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
        """Register a reference mask, distributing its objects across sub-engines.

        Args:
            img: Frame image. It is only encoded by sub-engines that take the
                reference-frame path.
            mask: Full mask covering all objects; ``self.separate_mask`` splits it
                per sub-engine.
            obj_nums: Total object count, either an int or a list whose first
                element is used.
            frame_step: Forwarded to sub-engines that take the reference-frame path.
        """
        total_objects = obj_nums[0] if isinstance(obj_nums, list) else obj_nums
        self.obj_nums = total_objects

        # Grow the pool until it has ceil(total / max_aot_obj_num) engines, and at
        # least one.
        engines_needed = max(np.ceil(total_objects / self.max_aot_obj_num), 1)
        while len(self.aot_engines) < engines_needed:
            spawned = DeAOTEngine(
                self.AOT,
                self.gpu_id,
                self.long_term_mem_gap,
                self.short_term_mem_skip,
                max_len_long_term=self.max_len_long_term,
            )
            spawned.eval()
            self.aot_engines.append(spawned)

        per_engine_masks, per_engine_counts = self.separate_mask(mask, total_objects)

        # The first engine's encoder embeddings are shared with every later
        # engine, so the image is encoded only once.
        shared_embs = None
        for engine, engine_mask, engine_count in zip(self.aot_engines, per_engine_masks, per_engine_counts):
            # An engine that has never seen objects, or now has more objects
            # than before, gets a full reference frame.
            needs_reference = engine.obj_nums is None or engine.obj_nums[0] < engine_count
            if needs_reference:
                engine.add_reference_frame(
                    img,
                    engine_mask,
                    obj_nums=[engine_count],
                    frame_step=frame_step,
                    img_embs=shared_embs,
                )
            else:
                # NOTE(review): on this path `img` and `frame_step` are not
                # forwarded. The engine relies on its last propagated LSTT
                # output — confirm this is intended.
                engine.update_short_term_memory(engine_mask)
            if shared_embs is None:
                shared_embs = engine.curr_enc_embs

        self.update_size()
